{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import itertools\n",
    "import sys\n",
    "if \"../\" not in sys.path:\n",
    "  sys.path.append(\"../\") \n",
    "from lib.utils import read_data_from_file, sign"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weighted_pocket_pla(X, y, updates=200, weight=[1000, 1]):\n",
    "    \"\"\"\n",
    "    Weighted Pocket Perceptron Learning Algorithm\n",
    "    Args:\n",
    "        X: 数据\n",
    "        Y: 标签\n",
    "        updates: 更新次数\n",
    "        weight: weight[0]->FP weight[1]->FN\n",
    "    \n",
    "    Returns:\n",
    "        w_pocket: 最优特征权重\n",
    "        w: 最后更新得到的特征权重\n",
    "    \"\"\"\n",
    "    \n",
    "    def cal_prob(mistakes):\n",
    "        if len(mistakes) == 0:\n",
    "            return 0, 0, 0\n",
    "        \n",
    "        # cal prob distribution\n",
    "        FP_index = np.where(y[mistakes]==-1)[0]\n",
    "        FN_index = np.where(y[mistakes]== 1)[0]\n",
    "        prob = np.ones_like(mistakes) / (len(FP_index)*weight[0] + len(FN_index)*weight[1]) \n",
    "        prob[FP_index] *= weight[0]\n",
    "        prob[FN_index] *= weight[1]\n",
    "        np.testing.assert_array_almost_equal(np.sum(prob), 1.0, decimal=5)\n",
    "        return len(FP_index), len(FN_index), prob\n",
    "    \n",
    "    N = len(X)\n",
    "    w = np.zeros_like(X[0])\n",
    "    w_pocket = w\n",
    "    mistakes = np.where(sign(X.dot(w)) != y)[0]         # get index of mistakes\n",
    "    \n",
    "    FP, FN, prob = cal_prob(mistakes)                   # cal weight info\n",
    "    mis_pocket = FP*weight[0] + FN*weight[1]\n",
    "    \n",
    "    for i in range(updates):\n",
    "        mistake = np.random.choice(mistakes, p=prob)            # pike up one mistake randomly\n",
    "        w = w + X[mistake]*y[mistake]\n",
    "        mistakes = np.where(sign(X.dot(w)) != y)[0]\n",
    "        \n",
    "        FP, FN, prob = cal_prob(mistakes)\n",
    "        if mis_pocket > FP*weight[0] + FN*weight[1]: \n",
    "            w_pocket = w\n",
    "            mis_pocket = FP*weight[0] + FN*weight[1]\n",
    "    return w_pocket, w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_err_rate(X, y, w):\n",
    "    mistakes = np.where(sign(X.dot(w)) != y)[0]\n",
    "    err_rate = len(mistakes) / len(X)\n",
    "    return err_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_all_err_example(X, y, w):\n",
    "    mistakes = np.where(sign(X.dot(w)) != y)[0]\n",
    "    return np.hstack((X[mistakes],Y[mistakes].reshape(-1,1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data_train shape:  (500, 5)\n",
      "data_test shape:  (500, 5)\n"
     ]
    }
   ],
   "source": [
    "data_train = read_data_from_file('hw1_18_train.dat') \n",
    "print('data_train shape: ', data_train.shape)\n",
    "data_test = read_data_from_file('hw1_18_test.dat')\n",
    "print('data_test shape: ', data_test.shape)\n",
    "sign = np.frompyfunc(sign, 1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y = data_train[:,-1]\n",
    "X = np.concatenate((np.ones((data_train.shape[0],1)), data_train[:,:-1]), axis=1)\n",
    "\n",
    "w_pocket, _ = weighted_pocket_pla(X,Y, weight=[1, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 14  38]\n",
      " [  7  68]\n",
      " [  0 186]\n",
      " [  0 176]\n",
      " [  0 188]\n",
      " [  0 185]]\n"
     ]
    }
   ],
   "source": [
    "errs = get_all_err_example(X, Y, w_pocket)\n",
    "\n",
    "result = []\n",
    "for i in [1, 5, 10, 50, 100, 1000]:\n",
    "    w_pocket, _ = weighted_pocket_pla(X,Y, weight=[i*2, 1])\n",
    "    errs = get_all_err_example(X, Y, w_pocket)\n",
    "    result.append((len(np.where(errs[:,-1] < 0)[0]), len(np.where(errs[:,-1] > 0)[0])))\n",
    "print(np.array(result))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
